
[core] Fix peft multi-gpu issue #145

Merged (7 commits, Mar 14, 2023)

Conversation

younesbelkada (Contributor)

What does this PR do?

This PR fixes the peft + LoRA multi-GPU issue. To reproduce, run:

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f'{free_in_GB-2}GB'

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", 
    device_map='auto',
    max_memory=max_memory
)
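# device_map='auto' asks accelerate to dispatch the layers across the visible GPUs
# (naive pipeline parallelism); max_memory caps how much each GPU may hold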

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b")  # all OPT checkpoints share the same tokenizer

print(model)

for param in model.parameters():
  param.requires_grad = False  # freeze the model - train adapters later
  if param.ndim == 1:
    # cast the small parameters (e.g. layernorm) to fp32 for stability
    param.data = param.data.to(torch.float32)

#model.gradient_checkpointing_enable()  # reduce number of stored activations
#model.model.decoder.project_in = lambda x: x.requires_grad_(True)

class CastOutputToFloat(nn.Sequential):
  def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)
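# wrapping lm_head this way keeps the logits in fp32, which stabilizes the loss under fp16 training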

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

from peft import LoraConfig, get_peft_model 

config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "out_proj", "fc1", "fc2"],
    lora_dropout=0.01,
    bias="none",
    task_type="CAUSAL_LM"
)
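# inject LoRA adapters into the listed OPT attention/MLP projections; after get_peft_model
# only the adapter weights remain trainable (the base weights were frozen above)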
model = get_peft_model(model, config)
print_trainable_parameters(model)


import transformers
from datasets import load_dataset
#data = load_dataset("lambada")
#data = data.map(lambda samples: tokenizer(samples['text']), batched=True)
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)

trainer = transformers.Trainer(
    model=model, 
    train_dataset=data['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4, 
        gradient_accumulation_steps=4,
        warmup_steps=10, 
        max_steps=20, 
        learning_rate=3e-4, 
        fp16=True,
        logging_steps=1, 
        output_dir='outputs'
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()


batch = tokenizer("Two things are infinite: ", return_tensors='pt')

model.config.use_cache = True  # re-enable the cache for generation
model.eval()
with torch.cuda.amp.autocast():
  output_tokens = model.generate(**batch, max_new_tokens=50)

print('\n\n', tokenizer.decode(output_tokens[0], skip_special_tokens=True))
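
As a quick sanity check that device_map='auto' really sharded the model over both GPUs, something like the following can be run after loading (a minimal sketch that just counts parameters per device):

from collections import defaultdict

params_per_device = defaultdict(int)
for _, param in model.named_parameters():
    params_per_device[str(param.device)] += param.numel()
# with naive pipeline parallelism both cuda:0 and cuda:1 should show up here
print(dict(params_per_device))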

cc @pacman100

younesbelkada requested a review from pacman100 on Mar 2, 2023
djaym7 commented Mar 6, 2023

Installed this and it works.

pacman100 (Contributor) left a comment


Hello @younesbelkada, thanks a lot for this PR enabling tuning with naive pipeline parallelism across a multi-GPU setup. Left a few nits.

    return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r > 0 and not self.merged:
    result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    if self.r > 0:
        result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
    result = result

pacman100: This can be removed, I think.

@@ -339,17 +344,20 @@ def eval(self):
        self.lora_B.eval()

    def forward(self, x: torch.Tensor):


pacman100: This can be removed, I think.

younesbelkada (Contributor, Author)

Thanks a lot for your review! Fixed the comments! Merging
